# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import math
import re
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision

from ultralytics.yolo.utils import LOGGER

from .metrics import box_iou


class Profile(contextlib.ContextDecorator):
    """
    YOLOv8 Profile class: accumulates the wall-clock time spent inside a code block.

    Usage: as a decorator with @Profile() or as a context manager with 'with Profile():'.
    When CUDA is available the device is synchronized before every timestamp, so queued GPU work is included.

    Attributes:
        t (float): Total accumulated time in seconds.
        dt (float): Duration of the most recent timed run (set on exit).
        start (float): Timestamp taken on entry.
        cuda (bool): Whether CUDA was available when the profiler was created.
    """

    def __init__(self, t=0.0):
        """
        Create a profiler.

        Args:
            t (float): Starting value of the accumulated time. Defaults to 0.0.
        """
        self.cuda = torch.cuda.is_available()
        self.t = t

    def __enter__(self):
        """Record the start timestamp and return this profiler."""
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):
        """Store the elapsed time of this run in ``dt`` and add it to the running total ``t``."""
        elapsed = self.time() - self.start
        self.dt = elapsed  # delta-time of the latest run
        self.t += elapsed  # running total

    def time(self):
        """Return the current time in seconds, synchronizing CUDA first when it is available."""
        if not self.cuda:
            return time.time()
        torch.cuda.synchronize()
        return time.time()


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    """
    Return the list mapping each 80-class COCO index to its 91-class paper category id.

    The paper ids run from 1 to 90; the ten ids below have no counterpart among the 80 training classes.
    See https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    """
    unused_ids = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
    return [cat_id for cat_id in range(1, 91) if cat_id not in unused_ids]


def segment2box(segment, width=640, height=640):
    """
    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)

    Args:
      segment (np.ndarray): the segment label, an (n, 2) array of xy points
      width (int): the width of the image. Defaults to 640
      height (int): The height of the image. Defaults to 640

    Returns:
      (np.ndarray): [x_min, y_min, x_max, y_max] of the points lying inside the image (with segment's dtype),
        or four zeros when no point lies inside the image.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    # Test for emptiness explicitly: `any(x)` would also report "no points" when every remaining x is 0
    # (a segment lying on the left image border), discarding a valid box.
    if x.size == 0:
        return np.zeros(4, dtype=segment.dtype)
    return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)  # xyxy


def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """
    Map xyxy boxes from a letterboxed image (img1_shape) back onto the original image (img0_shape).

    The boxes are modified in place (padding removed, gain divided out, then clipped to img0_shape) and returned.

    Args:
      img1_shape (tuple): (height, width) of the image the boxes currently refer to.
      boxes (torch.Tensor): boxes whose first four columns are (x1, y1, x2, y2).
      img0_shape (tuple): (height, width) of the target image.
      ratio_pad (tuple): optional ((ratio, ...), (pad_w, pad_h)). When omitted, the gain and padding are derived from
                         the two shapes assuming a centered letterbox.

    Returns:
      boxes (torch.Tensor): the same boxes object, rescaled to img0_shape, in (x1, y1, x2, y2) format.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad_w = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
        pad_h = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
        pad = pad_w, pad_h  # wh padding

    boxes[..., [0, 2]] -= pad[0]  # remove horizontal padding
    boxes[..., [1, 3]] -= pad[1]  # remove vertical padding
    boxes[..., :4] /= gain  # undo the resize
    clip_boxes(boxes, img0_shape)
    return boxes


def make_divisible(x, divisor):
    """
    Round x up to the next multiple of divisor.

    Args:
        x (int): The number to round.
        divisor (int | torch.Tensor): The divisor; for a tensor, its largest element is used.

    Returns:
        (int): The smallest multiple of the divisor that is >= x.
    """
    step = int(divisor.max()) if isinstance(divisor, torch.Tensor) else divisor
    return math.ceil(x / step) * step

from scipy.optimize import linear_sum_assignment
def IOU(boxA, boxB):
    """
    Intersection-over-union of two boxes given as (x1, y1, x2, y2, ...); only the first four entries are read.

    Accepts Python sequences or 1-D tensors. A small epsilon in the denominator avoids division by zero; a tensor
    result is moved to the CPU before being returned.
    """
    left, top = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    right, bottom = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    overlap = max(0, right - left) * max(0, bottom - top)
    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = float(area_a + area_b - overlap + 1e-7)
    ratio = overlap / union
    return ratio.cpu() if torch.is_tensor(ratio) else ratio

def xyxy_xywh(boxA):
    """Convert one (x1, y1, x2, y2, ...) box to a (x_center, y_center, width, height) tuple."""
    left, top, right, bottom = boxA[0], boxA[1], boxA[2], boxA[3]
    return (left + right) / 2, (top + bottom) / 2, right - left, bottom - top

def xywh_xyxy(x, y, w, h):
    """Convert a center-form box (x, y, w, h) to its corners, returned as (x1, y1, x2, y2)."""
    half_w, half_h = w / 2, h / 2
    return x - half_w, y - half_h, x + half_w, y + half_h

def Hcost(boxA, boxB):
    """
    Matching cost between two detections (x1, y1, x2, y2, conf, ...) for Hungarian assignment.

    ``scipy.optimize.linear_sum_assignment`` minimizes the total cost, so the overlap term must decrease as the boxes
    overlap more: it is ``1 - IoU`` (adding IoU itself would make the matcher prefer non-overlapping pairs). The
    confidence term is the relative confidence difference ``|conf_a - conf_b| / max(conf_a, conf_b)``; the epsilon
    keeps two zero-confidence (padded) rows from dividing by zero.

    Returns:
        The cost as a float, or as a CPU tensor when the inputs are tensors.
    """
    overlap_cost = 1 - IOU(boxA, boxB)
    conf_cost = abs(boxA[4] - boxB[4]) / (max(boxA[4], boxB[4]) + 1e-7)
    cost = overlap_cost + conf_cost
    return cost.cpu() if torch.is_tensor(cost) else cost


def nms(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=1,
        nc=0,
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
        before=False,
):
    """
    Single-slice NMS used by ``non_max_suppression``. It is the YOLOv8 routine plus the ``before`` option.

    Args:
        prediction (torch.Tensor): (batch_size, 4 + nc + nm, num_boxes) tensor whose channels 0-3 are xywh boxes.
            NOTE: on non-MPS devices the xywh -> xyxy conversion is written through a transposed view, so the
            caller's tensor is modified in place.
        conf_thres (float): Minimum class score for a box to be kept.
        iou_thres (float): IoU above which a lower-scoring box is suppressed by ``torchvision.ops.nms``.
        classes (List[int] | None): If given, keep only detections whose class index is in this list.
        agnostic (bool): If True, suppress across all classes instead of per class.
        multi_label (bool): If True (and nc > 1), a box may be kept once for every class above ``conf_thres``.
        labels (List[torch.Tensor]): Optional apriori labels per image, rows of (cls, 4 box values), appended as
            extra candidates.
        max_det (int): Maximum number of detections kept per image. Defaults to 1.
        nc (int): Number of class channels; 0 means ``prediction.shape[1] - 4``.
        max_time_img (float): Per-image time budget used to build the overall time limit.
        max_nms (int): Maximum number of boxes passed into ``torchvision.ops.nms``.
        max_wh (int): Per-class coordinate offset that separates classes in batched NMS.
        before (bool): If True, keep only (x1, y1, x2, y2, conf) per detection (class and mask columns dropped).

    Returns:
        (List[torch.Tensor]): One tensor per image, shaped (k, 6 + nm) with columns
            (x1, y1, x2, y2, conf, cls, mask...), or (k, 5) when ``before`` is True. Images without detections get a
            zero-row tensor of the same width. If the time limit is exceeded, processing stops and the remaining
            images keep their empty results.
    """
    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4  # number of mask channels
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy (writes through the view)

    t = time.time()
    # Empty results must have the same width as non-empty ones (``before`` keeps xyxy + conf only) and live on the
    # caller's device, like the non-empty results that are moved back below.
    out_cols = 5 if before else 6 + nm
    output = [torch.zeros((0, out_cols), device=device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            # Same width as x (4 box + nc class + nm mask columns) so the concatenation below is valid.
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        if before:
            x = x[:, :5]  # keep (x1, y1, x2, y2, conf) only
        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded
    return output

def _pad_rows(t, num, width):
    """Return ``t`` extended with zero rows of ``width`` columns until it has ``num`` rows."""
    missing = num - len(t)
    if missing <= 0:
        return t
    pad = torch.zeros(missing, width).to(t.device)
    # An empty ``t`` may carry a different column count, so it is replaced rather than concatenated.
    return pad if len(t) == 0 else torch.cat((t, pad), 0)


def _fuse_three_frames(out1, out2, out3):
    """
    Associate one image's detections from the three prediction slices and return fused (k, 16) rows.

    ``out1``/``out2`` hold (x1, y1, x2, y2, conf) rows and ``out3`` holds (x1, y1, x2, y2, conf, cls) rows; the slices
    are named t1, t2, t3 (presumably successive frames -- confirm against the model head).
    """
    num = max(len(out1), len(out2), len(out3))
    if num == 0:
        return torch.empty(0, 16).to(out3.device)
    out1, out2, out3 = _pad_rows(out1, num, 5), _pad_rows(out2, num, 5), _pad_rows(out3, num, 6)

    # Hungarian matching t3 -> t2 and t2 -> t1; rows of each cost matrix index the later slice. The matrices are
    # square, so the returned row indices are 0..num-1 and t1_ind[k] is the t1 match of t2 row k.
    cost_32 = [[float(Hcost(a, b)) for a in out2] for b in out3]
    t3_ind, t32_ind = linear_sum_assignment(cost_32)
    cost_21 = [[float(Hcost(a, b)) for a in out1] for b in out2]
    _, t1_ind = linear_sum_assignment(cost_21)

    rows = []
    for j in range(num):
        d3 = out3[t3_ind[j]]
        d2 = out2[t32_ind[j]]  # t2 detection matched to d3
        d1 = out1[t1_ind[t32_ind[j]]]  # t1 detection matched to d2
        if d3[4] == 0:  # padded slot: t3 has no detection here
            # Recover it only when both earlier slices agree on a confident, overlapping object; otherwise drop the
            # slot instead of emitting an all-zero detection.
            if not (d2[4] > 0.1 and d1[4] > 0.1 and IOU(d1, d2) > 0.5):
                continue
            d3[:5] = (d1 + d2) / 2  # averaged box and conf; cls stays 0 (t1/t2 rows carry no class column)
        rows.append(torch.cat((d1, d2, d3), 0))
    return torch.stack(rows) if rows else torch.empty((0, 16)).to(out3.device)


def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nc=0,  # number of classes (optional)
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
        cube=True,
        mod=False,
        training=True,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, optionally fusing three stacked prediction slices.

    Arguments:
        prediction (torch.Tensor): A tensor of shape (batch_size, channels, num_boxes) in the format output by the
            model (a list/tuple is reduced to its first element). Each 5-channel slice is xywh + one class score
            when ``nc`` is 0. NOTE: ``nms`` converts boxes to xyxy in place, so this tensor is modified.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU above which a lower-scoring overlapping box is suppressed during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, NMS is class-agnostic: boxes of different classes suppress each other.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[torch.Tensor]): Optional apriori labels per image, rows of (class_index, 4 box values).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels
        cube (bool): If True, ``prediction`` stacks three 5-channel slices (channels 0-4, 5-9, 10-14).
            If False, plain NMS on channels 0-4.
        mod (bool): With ``cube``, fuse the three slices (only when ``training`` is False).
        training (bool): When True, the fusion path is skipped even if ``mod`` is set.

    Returns:
        (List[torch.Tensor]): One tensor per image.
            - ``cube=False``: (k, 6) rows (x1, y1, x2, y2, conf, cls), plus mask columns if present.
            - ``cube`` and ``mod`` and not ``training``: (k, 16) rows = slice-1 (x1, y1, x2, y2, conf)
              + slice-2 (x1, y1, x2, y2, conf) + slice-3 (x1, y1, x2, y2, conf, cls), associated by Hungarian
              matching. Zero-padded where a slice had fewer detections.
            - otherwise with ``cube``: (k, 16) rows built from slice 3 alone, its (x1, y1, x2, y2, conf) repeated
              three times followed by cls.
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    def _run(channels, before):
        """Run ``nms`` on the given channel slice with this call's settings."""
        return nms(prediction[:, channels], conf_thres, iou_thres, classes, agnostic, multi_label, labels, max_det, nc,
                   max_time_img, max_nms, max_wh, before=before)

    if not cube:
        return _run(slice(0, 5), before=False)

    if mod and not training:
        output_t1 = _run(slice(0, 5), before=True)
        output_t2 = _run(slice(5, 10), before=True)
        output_t3 = _run(slice(10, 15), before=False)
        assert len(output_t1) == len(output_t2) == len(output_t3)  # == batchsize
        return [_fuse_three_frames(o1, o2, o3) for o1, o2, o3 in zip(output_t1, output_t2, output_t3)]

    # No fusion: keep the 16-column layout by repeating slice 3's (box, conf) for all three positions.
    return [torch.cat((o[:, 0:5].repeat(1, 3), o[:, 5:6]), 1) for o in _run(slice(10, 15), before=False)]


def clip_boxes(boxes, shape):
    """
    Clip xyxy bounding boxes in place so they lie inside an image of the given shape.

    Args:
      boxes (torch.Tensor | np.ndarray): boxes whose last dimension starts with (x1, y1, x2, y2); modified in place.
      shape (tuple): the image shape, read as (height, width, ...).
    """
    height, width = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, width)  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, height)  # y1, y2
        return
    # torch: clamp each column in place (faster individually)
    for column, limit in ((0, width), (1, height), (2, width), (3, height)):
        boxes[..., column].clamp_(0, limit)


def clip_coords(coords, shape):
    """
    Clip (x, y) point coordinates in place to the image boundaries.

    Args:
        coords (torch.Tensor | numpy.ndarray): Points whose last dimension starts with (x, y).
        shape (tuple): Image size, read as (height, width, ...).

    Returns:
        (None): `coords` is modified in place.
    """
    height, width = shape[0], shape[1]
    if not isinstance(coords, torch.Tensor):  # np.array (faster grouped)
        coords[..., 0] = coords[..., 0].clip(0, width)  # x
        coords[..., 1] = coords[..., 1].clip(0, height)  # y
        return
    coords[..., 0].clamp_(0, width)  # x
    coords[..., 1].clamp_(0, height)  # y


def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Remove letterbox padding from masks/images and resize them to the original image size.

    Args:
      masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3] (resized with cv2.resize).
      im0_shape (tuple): the original image shape, (height, width, ...).
      ratio_pad (tuple): optional ((ratio, ...), (pad_w, pad_h)); when omitted the padding is derived from the shapes.

    Returns:
      masks (np.ndarray): the input object unchanged if its height/width already match im0_shape, otherwise the
        cropped and resized masks, always with a trailing channel axis.
    """
    im1_shape = masks.shape
    if im1_shape[:2] == im0_shape[:2]:
        return masks  # already at the target size
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:  # calculate from im0_shape
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain  = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    pad_x, pad_y = pad[0], pad[1]
    top, left = int(pad_y), int(pad_x)
    bottom, right = int(im1_shape[0] - pad_y), int(im1_shape[1] - pad_x)

    if len(masks.shape) < 2:
        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
    cropped = masks[top:bottom, left:right]
    resized = cv2.resize(cropped, (im0_shape[1], im0_shape[0]))
    # cv2.resize drops a single-channel axis; restore it so callers always get [h, w, c].
    return resized[:, :, None] if len(resized.shape) == 2 else resized


def xyxy2xywh(x):
    """
    Convert boxes from corner form (x1, y1, x2, y2) to center form (x, y, width, height).

    Args:
        x (np.ndarray | torch.Tensor): Boxes with the four coordinates in the last dimension.
    Returns:
        y (np.ndarray | torch.Tensor): A new array/tensor of the same type and shape, in (x, y, width, height) format.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    out[..., 0] = (left + right) / 2  # x center
    out[..., 1] = (top + bottom) / 2  # y center
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    return out


def xywh2xyxy(x):
    """
    Convert boxes from center form (x, y, width, height) to corner form (x1, y1, x2, y2), where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): Boxes with the four coordinates in the last dimension.
    Returns:
        y (np.ndarray | torch.Tensor): A new array/tensor of the same type and shape, in (x1, y1, x2, y2) format.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    out[..., 0] = x[..., 0] - half_w  # top left x
    out[..., 1] = x[..., 1] - half_h  # top left y
    out[..., 2] = x[..., 0] + half_w  # bottom right x
    out[..., 3] = x[..., 1] + half_h  # bottom right y
    return out


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """
    Convert normalized center-form boxes (x, y, w, h in [0, 1] image units) to pixel corner form.

    Args:
        x (np.ndarray | torch.Tensor): The normalized boxes, coordinates in the last dimension.
        w (int): Width of the image. Defaults to 640
        h (int): Height of the image. Defaults to 640
        padw (int): Offset added to x coordinates. Defaults to 0
        padh (int): Offset added to y coordinates. Defaults to 0
    Returns:
        y (np.ndarray | torch.Tensor): A new array/tensor with boxes as [x1, y1, x2, y2] in pixels, where x1,y1 is
            the top-left corner and x2,y2 the bottom-right corner.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    out[..., 0] = w * (x[..., 0] - half_w) + padw  # top left x
    out[..., 1] = h * (x[..., 1] - half_h) + padh  # top left y
    out[..., 2] = w * (x[..., 0] + half_w) + padw  # bottom right x
    out[..., 3] = h * (x[..., 1] + half_h) + padh  # bottom right y
    return out


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """
    Convert pixel corner-form boxes (x1, y1, x2, y2) to center form (x, y, width, height) normalized by the image
    width and height.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
        w (int): The width of the image. Defaults to 640
        h (int): The height of the image. Defaults to 640
        clip (bool): If True, ``x`` itself is first clipped in place to (h - eps, w - eps). Defaults to False
        eps (float): Margin subtracted from the image size when clipping. Defaults to 0.0
    Returns:
        y (np.ndarray | torch.Tensor): A new array/tensor in normalized (x, y, width, height) format.
    """
    if clip:
        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    out[..., 0] = ((left + right) / 2) / w  # x center
    out[..., 1] = ((top + bottom) / 2) / h  # y center
    out[..., 2] = (right - left) / w  # width
    out[..., 3] = (bottom - top) / h  # height
    return out


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    """
    Convert normalized (x, y) points to pixel coordinates.

    Args:
        x (np.ndarray | torch.Tensor): Normalized points, with (x, y) in the last dimension.
        w (int): The width of the image. Defaults to 640
        h (int): The height of the image. Defaults to 640
        padw (int): Offset added to x. Defaults to 0
        padh (int): Offset added to y. Defaults to 0
    Returns:
        y (np.ndarray | torch.Tensor): A new array/tensor with the points in pixel coordinates.
    """
    scaled = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    scaled[..., 0] = w * x[..., 0] + padw  # x
    scaled[..., 1] = h * x[..., 1] + padh  # y
    return scaled


def xywh2ltwh(x):
    """
    Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.

    The conversion operates on the last dimension, so a single box (4,), a batch of boxes (n, 4) and batched
    inputs (..., 4) are all supported.

    Args:
        x (np.ndarray | torch.Tensor): The input boxes in xywh (center x, center y, width, height) format
    Returns:
        y (np.ndarray | torch.Tensor): A copy of the boxes in ltwh (left, top, width, height) format
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Index the last axis with `...` so batched (b, n, 4) input isn't mis-indexed along the box axis.
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    return y


def xyxy2ltwh(x):
    """
    Convert boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.

    The conversion operates on the last dimension, so (4,), (n, 4) and batched (..., 4) inputs are all supported.

    Args:
      x (np.ndarray | torch.Tensor): The input boxes in xyxy format
    Returns:
      y (np.ndarray | torch.Tensor): A copy of the boxes in ltwh (left, top, width, height) format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexing keeps batched (b, n, 4) input correct instead of indexing the box axis.
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


def ltwh2xywh(x):
    """
    Convert boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.

    The conversion operates on the last dimension, so (4,), (n, 4) and batched (..., 4) inputs are all supported.

    Args:
      x (np.ndarray | torch.Tensor): The input boxes in ltwh (left, top, width, height) format

    Returns:
      y (np.ndarray | torch.Tensor): A copy of the boxes in xywh (center x, center y, width, height) format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexing keeps batched (b, n, 4) input correct instead of indexing the box axis.
    y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
    y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
    return y


def ltwh2xyxy(x):
    """
    Convert boxes from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

    The conversion operates on the last dimension, so (4,), (n, 4) and batched (..., 4) inputs are all supported.

    Args:
      x (np.ndarray | torch.Tensor): The input boxes in ltwh (left, top, width, height) format

    Returns:
      y (np.ndarray | torch.Tensor): A copy of the boxes in xyxy format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexing keeps batched (b, n, 4) input correct instead of indexing the box axis.
    y[..., 2] = x[..., 2] + x[..., 0]  # bottom right x = left + width
    y[..., 3] = x[..., 3] + x[..., 1]  # bottom right y = top + height
    return y


def segments2boxes(segments):
    """
    Convert polygon segments to boxes, i.e. (xy1, xy2, ...) to (xywh), using each segment's min/max extent.

    Args:
      segments (list): list of segments, each an (n, 2) array of x, y points

    Returns:
      (np.ndarray): (len(segments), 4) array of xywh boxes; shape (0, 4) when `segments` is empty.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    # reshape so an empty input yields a (0, 4) array instead of shape (0,), which xyxy2xywh cannot index
    return xyxy2xywh(np.array(boxes).reshape(-1, 4))  # xywh


def resample_segments(segments, n=1000):
    """
    Resample every (k, 2) polygon in a list so it has exactly `n` points, interpolating linearly along the outline.

    The polygon is closed (its first point is appended at the end) before sampling. The list is updated in place
    and also returned.

    Args:
      segments (list): a list of (k, 2) arrays of x, y points; k may differ between segments.
      n (int): number of points each resampled segment will have. Defaults to 1000

    Returns:
      segments (list): the same list, with each entry replaced by an (n, 2) float32 array.
    """
    for idx, seg in enumerate(segments):
        closed = np.concatenate((seg, seg[0:1, :]), axis=0)  # repeat first point to close the polygon
        sample_at = np.linspace(0, len(closed) - 1, n)
        knots = np.arange(len(closed))
        per_axis = [np.interp(sample_at, knots, closed[:, axis]) for axis in range(2)]
        segments[idx] = np.concatenate(per_axis, dtype=np.float32).reshape(2, -1).T  # segment xy
    return segments


def crop_mask(masks, boxes):
    """
    Crop each mask to its bounding box by zeroing every pixel that lies outside the box.

    Args:
      masks (torch.Tensor): [n, h, w] tensor of masks
      boxes (torch.Tensor): [n, 4] tensor of (x1, y1, x2, y2) box coordinates in the masks' pixel space

    Returns:
      (torch.Tensor): [n, h, w] masks with values outside each box set to zero.
    """
    _, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # each shape(n,1,1)
    cols = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # x indices shape(1,1,w)
    rows = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # y indices shape(1,h,1)

    # Broadcasting yields an (n, h, w) inside-box indicator; x2/y2 are exclusive bounds.
    return masks * ((cols >= x1) * (cols < x2) * (rows >= y1) * (rows < y2))


def process_mask_upsample(protos, masks_in, bboxes, shape):
    """
    Build instance masks from prototypes, upsample them to the input image size, then crop to the boxes.
    Upsampling before cropping gives higher-quality masks than `process_mask`, at a higher cost.

    Args:
      protos (torch.Tensor): [mask_dim, mask_h, mask_w] prototype masks
      masks_in (torch.Tensor): [n, mask_dim] mask coefficients, n is number of masks after nms
      bboxes (torch.Tensor): [n, 4] boxes in input-image pixel coordinates, n is number of masks after nms
      shape (tuple): the size of the input image (h,w)

    Returns:
      (torch.Tensor): [n, h, w] masks thresholded at 0.5 (values 0.0 / 1.0).
    """
    mask_dim, proto_h, proto_w = protos.shape  # CHW
    flat_protos = protos.float().view(mask_dim, -1)
    probs = (masks_in @ flat_protos).sigmoid().view(-1, proto_h, proto_w)
    upsampled = F.interpolate(probs[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    cropped = crop_mask(upsampled, bboxes)  # CHW
    return cropped.gt_(0.5)


def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """
    Apply masks to bounding boxes using the output of the mask head, cropping at prototype resolution.

    Args:
        protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
        masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
        bboxes (torch.Tensor): A tensor of shape [n, 4] in input-image pixel coordinates, where n is the number of
            masks after NMS.
        shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
        upsample (bool): A flag to indicate whether to upsample the mask to the input image size. Default is False.

    Returns:
        (torch.Tensor): A mask tensor thresholded at 0.5 (values 0.0 / 1.0). Its shape is [n, h, w] (input image
            size) when `upsample` is True, otherwise [n, mask_h, mask_w].
    """
    mask_dim, proto_h, proto_w = protos.shape  # CHW
    img_h, img_w = shape
    masks = (masks_in @ protos.float().view(mask_dim, -1)).sigmoid().view(-1, proto_h, proto_w)  # CHW

    # Rescale box coordinates from input-image pixels to prototype pixels (x by width ratio, y by height ratio).
    sx, sy = proto_w / img_w, proto_h / img_h
    downsampled_bboxes = bboxes.clone()
    for col, scale in ((0, sx), (1, sy), (2, sx), (3, sy)):
        downsampled_bboxes[:, col] *= scale

    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if not upsample:
        return masks.gt_(0.5)
    masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    return masks.gt_(0.5)


def process_mask_native(protos, masks_in, bboxes, shape):
    """
    Build instance masks, strip the letterbox padding at prototype resolution, upsample to `shape`, then crop.

    Args:
      protos (torch.Tensor): [mask_dim, mask_h, mask_w] prototype masks
      masks_in (torch.Tensor): [n, mask_dim] mask coefficients, n is number of masks after nms
      bboxes (torch.Tensor): [n, 4] boxes in pixel coordinates of `shape`, n is number of masks after nms
      shape (tuple): the target image size (h,w)

    Returns:
      masks (torch.Tensor): [n, h, w] masks thresholded at 0.5 (values 0.0 / 1.0)
    """
    mask_dim, proto_h, proto_w = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(mask_dim, -1)).sigmoid().view(-1, proto_h, proto_w)

    # Letterbox geometry expressed at prototype resolution: scale and symmetric padding on each axis.
    gain = min(proto_h / shape[0], proto_w / shape[1])  # gain  = old / new
    pad_w = (proto_w - shape[1] * gain) / 2
    pad_h = (proto_h - shape[0] * gain) / 2
    top, bottom = int(pad_h), int(proto_h - pad_h)
    left, right = int(pad_w), int(proto_w - pad_w)
    unpadded = masks[:, top:bottom, left:right]

    resized = F.interpolate(unpadded[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    return crop_mask(resized, bboxes).gt_(0.5)  # CHW


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False):
    """
    Map (x, y) point coordinates from img1_shape back to img0_shape, undoing letterbox scaling and padding.
    `coords` is modified in place and also returned.

    Args:
      img1_shape (tuple): (h, w) of the image the coords currently refer to.
      coords (torch.Tensor): points with x, y in the first two entries of the last dimension
      img0_shape (tuple): (h, w) of the image to map the coords onto
      ratio_pad (tuple): optional ((gain, ...), (pad_w, pad_h)); when None, both are derived from the two shapes.
      normalize (bool): If True, divide x by img0 width and y by img0 height after clipping. Defaults to False

    Returns:
      coords (torch.Tensor): the rescaled (and optionally normalized) coords.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:  # derive letterbox gain and per-side padding from the shapes
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding

    for axis in (0, 1):  # 0 -> x, 1 -> y
        coords[..., axis] -= pad[axis]
        coords[..., axis] /= gain
    clip_coords(coords, img0_shape)

    if normalize:
        coords[..., 0] /= img0_shape[1]  # width
        coords[..., 1] /= img0_shape[0]  # height
    return coords


def masks2segments(masks, strategy='largest'):
    """
    Convert a stack of binary masks (n, h, w) into a list of polygon segments, one per mask.

    Args:
      masks (torch.Tensor): (n, h, w) tensor of masks; converted to uint8 before contour extraction
      strategy (str): 'concat' to join all external contours of a mask, or 'largest' to keep only the contour
        with the most points. Defaults to largest

    Returns:
      segments (List): list of (k, 2) float32 arrays of x, y points; a (0, 2) array for a mask with no contour

    Raises:
      ValueError: if a mask has contours and `strategy` is neither 'concat' nor 'largest'.
    """
    segments = []
    for mask in masks.int().cpu().numpy().astype('uint8'):
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if contours:
            if strategy == 'concat':  # concatenate all segments
                seg = np.concatenate([ct.reshape(-1, 2) for ct in contours])
            elif strategy == 'largest':  # select largest segment
                seg = np.array(contours[np.array([len(ct) for ct in contours]).argmax()]).reshape(-1, 2)
            else:
                # Previously this fell through and crashed on `.astype` of a tuple with an opaque AttributeError.
                raise ValueError(f"Invalid strategy '{strategy}', expected 'concat' or 'largest'.")
        else:
            seg = np.zeros((0, 2))  # no segments found
        segments.append(seg.astype('float32'))
    return segments


def clean_str(s):
    """
    Return a copy of `s` in which every character from a fixed set of special characters becomes an underscore.

    Args:
      s (str): the string to sanitize

    Returns:
      (str): the sanitized string; characters outside the special set are left unchanged
    """
    special_chars = re.compile('[|@#!¡·$€%&()=?¿^*;:,¨´><+]')
    return special_chars.sub('_', s)
